# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs the preprocessing job to produce records for training."""

import argparse
import ConfigParser
import logging
import os
import sys

import apache_beam as beam

from preprocessing import preprocess


def _parse_arguments(argv):
  """Parses command line arguments."""
  parser = argparse.ArgumentParser(
      description='Runs preprocessing on census train data.')
  parser.add_argument(
      '--project_id',
      required=True,
      help='Name of the project.')
  parser.add_argument(
      '--job_name',
      required=False,
      help='Name of the dataflow job.')
  parser.add_argument(
      '--job_dir',
      required=True,
      help='Directory to write outputs.')
  parser.add_argument(
      '--cloud',
      default=False,
      action='store_true',
      help='Run preprocessing on the cloud.')
  parser.add_argument(
      '--input_data',
      required=True,
      help='Path to input data.')
  args, _ = parser.parse_known_args(args=argv[1:])
  return args


def _parse_config(env, config_file_path):
  """Parses configuration file.

  Args:
    env: The environment in which the preprocessing job will be run.
    config_file_path: Path to the configuration file to be parsed.

  Returns:
    A dictionary containing the parsed runtime config.
  """
  config = ConfigParser.ConfigParser()
  config.read(config_file_path)
  return dict(config.items(env))


def _set_logging(log_level):
  logging.getLogger().setLevel(getattr(logging, log_level.upper()))


def main():
  """Configures pipeline and spawns preprocessing job."""

  args = _parse_arguments(sys.argv)
  # Select the config section matching the target execution environment.
  environment = 'CLOUD' if args.cloud else 'LOCAL'
  config_path = os.path.abspath(
      os.path.join(__file__, os.pardir, 'preprocessing_config.ini'))
  config = _parse_config(environment, config_path)

  options = {'project': args.project_id}
  if args.cloud:
    # Dataflow runs require an explicit job name; fail fast if absent.
    if not args.job_name:
      raise ValueError('Job name must be specified for cloud runs.')
    setup_path = os.path.abspath(
        os.path.join(__file__, '../..', 'dataflow_setup.py'))
    options['job_name'] = args.job_name
    options['num_workers'] = int(config.get('num_workers'))
    options['max_num_workers'] = int(config.get('max_num_workers'))
    options['staging_location'] = os.path.join(args.job_dir, 'staging')
    options['temp_location'] = os.path.join(args.job_dir, 'tmp')
    options['region'] = config.get('region')
    options['setup_file'] = setup_path

  pipeline_options = beam.pipeline.PipelineOptions(flags=[], **options)
  _set_logging(config.get('log_level'))

  with beam.Pipeline(
      config.get('runner'), options=pipeline_options) as pipeline:
    preprocess.run(pipeline, args.input_data, args.job_dir)


if __name__ == '__main__':
  main()
